from distutils.command.clean import clean
import numpy as np
from utils import train_model, test_model, get_needed_dirs
from utils import compute_loss

import pickle 
import argparse
from data_utils import get_dataset
import argparse
import os
import sys
from scipy.io import savemat, loadmat
import scipy.sparse as sparse

# Command-line interface. Arguments are parsed at module level, so `args`
# is a module global that is handed to generate_target_model() below.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='mnist_17',help="options: mnist_17,dogfish,enron,mnist_38,mnist_69,mnist_49,cifar10_05")
parser.add_argument('--model_type',default='svm',help='victim model type: SVM or logistic regression')
parser.add_argument('--weight_decay',default=0.09, type=float, help='weight decay for regularizers')
# Without this flag the test split is used (use_test = not args.use_train);
# the help text states what the flag actually switches on.
parser.add_argument('--use_train',action="store_true", help='training data (instead of the default test data) will be leveraged in the attack process')
parser.add_argument('--original',action="store_true", help='we will leverage the original target model generation process')
parser.add_argument('--save_to_mat',action="store_true", help='convert the python target models into mat format')
parser.add_argument('--rand_seed',default=1234, type=int, help='seed for random number generator')
parser.add_argument('--check_transfer',action="store_true", help='check the transferability of different models')

args = parser.parse_args()

# imdb always runs with a fixed weight decay, overriding --weight_decay.
if args.dataset == 'imdb':
    args.weight_decay = 0.01

def _flip_low_margin_points(X_use, Y_use, Y_flipped, margins, y_list, loss_quantile, tar_rep, as_sparse):
    """Build the label-flipped poison set from the lowest-margin points of each class.

    Points of every label in `y_list` whose margin lies below the
    `loss_quantile` quantile of `margins` are selected, given the labels in
    `Y_flipped`, and repeated `tar_rep` times. When `as_sparse` is true the
    features are returned as a CSR matrix, otherwise as a dense array.
    """
    margin_thresh = np.quantile(margins, loss_quantile)
    X_parts = []
    Y_parts = []
    for label in y_list:
        active = np.logical_and(Y_use == label, margins < margin_thresh)
        X_parts.append(X_use[active, :])
        Y_parts.append(Y_flipped[active])

    if as_sparse:
        # np.repeat needs a dense array; convert back to CSR afterwards.
        X_tar = sparse.vstack(X_parts, format='csr').toarray()
    else:
        X_tar = np.concatenate(X_parts, axis=0)
    Y_tar = np.concatenate(Y_parts, axis=0)

    X_tar = np.repeat(X_tar, tar_rep, axis=0)
    Y_tar = np.repeat(Y_tar, tar_rep, axis=0)
    if as_sparse:
        X_tar = sparse.csr_matrix(X_tar)
    return X_tar, Y_tar


def _target_model_dir(args, n_features, epoch, train_or_test_data):
    """Return the output directory path for the generated target classifiers."""
    if args.dataset == '2d_toy':
        return 'files/target_classifiers/{}/{}/{}_D/{}/Sep_{}_Flip_{}_Addi_{}/{}/{}'.format(
            '2d_toy', args.model_type, n_features, args.n_samples, args.class_sep,
            args.flip_y, args.addi_search_space, args.weight_decay, train_or_test_data)
    dataset = args.dataset
    if dataset == 'cifar10_trial':
        dataset = '{}_{}'.format(dataset, epoch)
    return 'files/target_classifiers/{}/{}/{}/{}'.format(
        dataset, args.model_type, args.weight_decay, train_or_test_data)


def generate_target(quantile_tape,rep_tape,y_list,X_train,Y_train,X_test,Y_test,target_errors,args,use_test=False,epoch=None):
    """Search label-flipping attacks for target classifiers at given error levels.

    For every (loss quantile, repetition count) pair, the lowest-margin points
    are label-flipped, repeated, appended to the training set, and a model is
    retrained on the result. Each retrained model is assigned to every bucket
    ``[target_errors[i], target_errors[i+1]]`` that contains its clean error;
    within a bucket the model with the lowest clean training loss is kept.

    Args:
        quantile_tape: margin quantiles to try.
        rep_tape: repetition counts to try for the flipped points.
        y_list: class labels whose points may be flipped. Labels are flipped
            by multiplying with -1, so they are presumably in {-1, +1}.
        X_train, Y_train: clean training data (dense array or scipy sparse).
        X_test, Y_test: clean test data.
        target_errors: ascending error thresholds (list or array). The last
            entry only bounds the final bucket and gets no model of its own.
        args: parsed CLI namespace. Reads `original`, `use_train`, `dataset`,
            `model_type`, `weight_decay`, plus the 2d_toy-specific fields when
            `dataset == '2d_toy'`.
        use_test: if True, points are selected from and errors are measured
            on the test split; otherwise the training split is used.
        epoch: suffix appended to the dataset name for 'cifar10_trial'.

    Side effects:
        Writes one pickle per successful target error plus
        `generated_errors.npy` under the target-model directory, and prints
        progress to stdout. Returns None.
    """
    # Boolean-mask indexing at the end requires an ndarray; accept lists too.
    target_errors = np.asarray(target_errors)

    model = train_model(X_train,Y_train,args)
    orig_theta = model.coef_.reshape(-1)
    orig_bias = model.intercept_

    # NOTE(review): losses are always evaluated with the 'svm' loss,
    # independent of args.model_type - confirm this is intended.
    used_model_type = 'svm'

    if use_test:
        X_use,Y_use = X_test,Y_test
        data_type = 'Test'
    else:
        X_use,Y_use = X_train,Y_train
        data_type = 'Train'

    all_data_info = {}
    for target_error in target_errors:
        all_data_info[target_error] = {
            'best_train_loss_w_reg': 1e10,
            'best_train_loss': 1e10,
        }

    success_flag = np.zeros(len(target_errors),dtype=bool)
    as_sparse = sparse.issparse(X_train)
    num_buckets = len(target_errors)-1
    Y_flipped = (-1)*Y_use
    clean_margins = Y_use*(X_use.dot(orig_theta) + orig_bias)

    for loss_quantile in quantile_tape:
        for tar_rep in rep_tape:
            print(" ----- Loss Quantile {} and Repetition Number {} ------".format(loss_quantile, tar_rep))
            X_tar, Y_tar = _flip_low_margin_points(
                X_use, Y_use, Y_flipped, clean_margins, y_list, loss_quantile, tar_rep, as_sparse)

            if as_sparse:
                X_train_p = sparse.vstack((X_train,X_tar),format='csr')
            else:
                X_train_p = np.concatenate((X_train,X_tar),axis = 0)
            Y_train_p = np.concatenate((Y_train,Y_tar),axis = 0)

            # retrain on clean + poisoned points
            model_p = train_model(X_train_p,Y_train_p,args)
            target_theta, target_bias = model_p.coef_.reshape(-1), model_p.intercept_

            # loss / accuracy on the poison points only, kept for the record
            train_loss_poison = np.mean(compute_loss(used_model_type,X_tar,Y_tar,target_theta,target_bias,margin_only=False))
            train_acc_poison = model_p.score(X_tar,Y_tar)

            # loss / accuracy on the clean training set
            train_loss = np.mean(compute_loss(used_model_type,X_train,Y_train,target_theta,target_bias,margin_only=False))
            train_loss_w_reg = train_loss + (args.weight_decay/2) * np.linalg.norm(target_theta)**2
            train_acc = model_p.score(X_train,Y_train)
            train_err = 1-train_acc
            if not use_test:
                print("clean train acc (with reg):{:.4f}, clean train loss:{:.4f}".format(train_acc,train_loss))

            # loss / accuracy on the clean test set
            test_loss = np.mean(compute_loss(used_model_type,X_test,Y_test,target_theta,target_bias,margin_only=False))
            test_acc = model_p.score(X_test,Y_test)
            test_err = 1 - test_acc
            if use_test:
                print("clean test acc:{:.4f}, clean test loss:{:.4f}".format(test_acc,test_loss))

            # the error that decides which bucket(s) this model belongs to
            compare_metric = test_err if use_test else train_err

            # update the best fit target model of every matching bucket
            for _i in range(num_buckets):
                curr_error = target_errors[_i]
                next_error = target_errors[_i+1]

                if compare_metric >= curr_error and compare_metric <= next_error:
                    print(compare_metric,curr_error,next_error)
                    best = all_data_info[curr_error]
                    if train_loss <= best['best_train_loss']:
                        print("----- Update Target Error {}, Next Target Error {}: Loss Quantile {} and Repetition Number {} ------".format(
                            curr_error,next_error,loss_quantile, tar_rep))
                        print("{} Error: {}, Train Loss: {}, Train Loss W/ Reg: {}".format(data_type,compare_metric,train_loss,train_loss_w_reg))
                        best['best_theta'] = target_theta
                        best['best_bias'] = target_bias
                        best['best_train_loss_w_reg'] = train_loss_w_reg
                        best['best_train_loss'] = train_loss
                        best['best_train_error'] = train_err
                        best['best_test_loss'] = test_loss
                        best['best_test_error'] = test_err
                        best['best_poison_loss'] = train_loss_poison
                        best['best_poison_error'] = 1-train_acc_poison
                        best['poison_num'] = X_tar.shape[0]
                        best['X_poison'] = X_tar
                        best['Y_poison'] = Y_tar
                        success_flag[_i] = True
                # a higher bucket already has a model while this one is still empty
                if not success_flag[_i] and success_flag[_i+1]:
                    print("cannot generate target error {} properly, reevaluate the label flipping attack!".format(curr_error))

            # updating clean_margins will produce models with even lower loss on clean data, we directly update this to make sure
            # we are generating better target models
            if not args.original:
                clean_margins = Y_use*(X_use.dot(target_theta) + target_bias)

    if np.sum(success_flag) < len(success_flag)-1:
        print("failed to generate all desired target models!")

    # The directory name follows the CLI flag, not the `use_test` parameter.
    if not args.use_train:
        train_or_test_data = 'use_test'
    else:
        train_or_test_data = 'use_train'

    target_model_dir = _target_model_dir(args, X_train.shape[1], epoch, train_or_test_data)
    os.makedirs(target_model_dir, exist_ok=True)

    # prepare to save the target models also in mat form to better evaluate original min-max attack
    improve_types = ['improved']
    save_dict = {}
    save_dict['thetas'] = []
    save_dict['biases'] = []
    save_dict['train_losses'] = []
    save_dict['test_errors'] = []
    save_dict['quantiles'] = [0.1]*len(target_errors)*len(improve_types)
    save_dict['reps'] = [1]*len(target_errors)*len(improve_types)

    file_prefix = 'orig' if args.original else 'improved'
    for _i in range(num_buckets):
        target_error = target_errors[_i]
        if not success_flag[_i]:
            print("failed to generate target model of error {}".format(target_error))
            continue

        info = all_data_info[target_error]
        print("--- best target classifier with target error rate {} ---- ".format(target_error))
        print("Num of Poisons Used:",info['poison_num'])
        print("Train Error of best theta",info['best_train_error'])
        print("Train Loss of best theta",info['best_train_loss'])
        print("Test Error of best theta:",info['best_test_error'])
        print("Test Loss of best theta:",info['best_test_loss'])
        print("Poison Train Error of best theta:",info['best_poison_error'])
        print("Poison Train Loss of best theta:",info['best_poison_loss'])

        # dump everything recorded for this target error; the context manager
        # closes the file even if pickling fails
        out_path = '{}/{}_best_theta_whole_err-{}'.format(target_model_dir,file_prefix,target_error)
        with open(out_path, 'wb') as file_all:
            pickle.dump(info, file_all,protocol=2)

        # collect the target model info for the mat format
        save_dict['thetas'].append(info['best_theta'])
        save_dict['biases'].append(info['best_bias'])
        tar_model_train_loss = info['best_train_loss']
        tar_model_train_err = info['best_train_error']
        tar_model_test_loss = info['best_test_loss']
        tar_model_test_err = info['best_test_error']
        save_dict['train_losses'].append(tar_model_train_loss)
        save_dict['test_errors'].append(tar_model_test_err)
        print("---- Saved into Mat Target Model Error-{} Info----".format(target_error))
        print("Train Error: {:.5f}, Train Loss: {:.5f}, Test Error: {:.5f}, Test Loss: {:.5f}".format(
            tar_model_train_err,tar_model_train_loss,tar_model_test_err,tar_model_test_loss))

    # NOTE(review): the .mat export is hard-disabled, so `save_dict` is built
    # but never written. Presumably this was meant to follow --save_to_mat;
    # confirm before enabling.
    if False:
        # save into .mat format
        mat_save_dir = 'files/target_classifiers/{}/{}/{}/{}/py_to_mat'.format(args.dataset,args.model_type,\
                args.weight_decay,train_or_test_data)
        if not os.path.isdir(mat_save_dir):
            os.makedirs(mat_save_dir)
        mat_save_fname = '{}/{}_thetas_with_bias_exact_decay_{}_py_to_mat.mat'.format(mat_save_dir,args.dataset,int(round(100*args.weight_decay)))
        savemat(mat_save_fname, save_dict)

    generated_errors = target_errors[success_flag]
    print("Generated target errors are:",generated_errors)
    fname = '{}/generated_errors.npy'.format(target_model_dir)
    np.save(fname,generated_errors)

def generate_target_model(args):
    """Pick the per-dataset search schedule and run the target-model search.

    Selects the quantile tape, repetition tape and target-error list for
    `args.dataset`, trains a clean model to measure the clean error, drops
    every target error that is not above it, and calls `generate_target`.

    The last entry of each target-error list is only a stopping bound:
    `generate_target` never produces a model for it.

    Args:
        args: parsed CLI namespace. Reads `dataset` and `use_train`;
            `weight_decay` is overwritten with 0.01 for the imdb dataset.

    Raises:
        ValueError: if no complete schedule is defined for `args.dataset`.
    """
    if args.dataset == 'imdb':
        args.weight_decay = 0.01

    quantile_tape = None
    rep_tape = None
    target_errors = None

    if args.dataset == 'mnist_17':
        quantile_tape = [0.01,0.02,0.03,0.04,0.05,0.10,0.15,0.20,0.25,0.30,0.35,0.40,0.45,0.50,0.55,0.6]
        rep_tape = [1, 2, 3, 5, 8, 10, 12, 15, 20, 25, 30, 50, 60]
        target_errors = [0.03,0.04,0.05,0.06,0.07,0.09,0.11,0.15,0.17,0.2,0.23,0.25,0.3,0.35,0.4,0.45,0.5,0.6]
    elif args.dataset == 'dogfish':
        quantile_tape = [0.01,0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.45, 0.50, 0.55,0.6, 0.7, 0.8]
        rep_tape = [1, 2, 3, 4, 5, 6, 7, 8,10,12,15]
        if not args.use_train:
            target_errors = [0.05,0.1,0.15,0.2,0.25,0.3,0.35,0.4,0.45,0.5,0.55,0.6,0.65,0.7,0.8,0.9]
        else:
            target_errors = [0.03,0.05,0.07,0.1,0.15,0.2,0.25,0.3,0.35,0.4,0.45,0.5,0.55,0.6,0.65,0.7,0.8,0.9]
    elif args.dataset == '2d_toy':
        # Only target errors are defined for 2d_toy; with no quantile or
        # repetition tape this dataset is rejected by the check below.
        target_errors = 0.1*np.arange(10)
    elif args.dataset == 'adult':
        quantile_tape = [0.01,0.02,0.03,0.04,0.05,0.10,0.15,0.20,0.25,0.30,0.35,0.40,0.45,0.50,0.55,0.6]
        rep_tape = [1, 2, 3, 4, 5, 6, 7, 8,10,12,15,20,25,30,50,60]
        target_errors = [0.23,0.25,0.3,0.33,0.35,0.4,0.45,0.5,0.55,0.6,0.65,0.7]
    elif args.dataset in ['mnist_38','mnist_69','mnist_49']:
        quantile_tape = [0.01,0.02,0.03,0.04,0.05,0.10,0.15,0.20,0.25,0.30,0.35,0.40,0.45,0.50,0.55,0.6]
        rep_tape = [1, 2, 3, 5, 8, 10, 12, 15, 20, 25, 30, 50, 60]
        target_errors = [0.03,0.04,0.05,0.06,0.07,0.09,0.11,0.12,0.15,0.17,0.2,0.23,0.25,0.3,0.35,0.4,0.45,0.5,0.55,0.6]
    elif args.dataset in ['enron','imdb','filtered_enron']:
        quantile_tape = [0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.45, 0.50, 0.55, 0.6, 0.7, 0.8]
        rep_tape = [1, 2, 3, 5, 8, 12, 18, 25, 33, 40]
        target_errors = [0.05,0.1,0.15,0.2,0.25,0.3,0.35,0.4,0.45,0.5,0.55,0.6,0.65,0.7,0.8]
    elif args.dataset in ['cifar10_05','cifar10_14']:
        quantile_tape = [0.01,0.02,0.03,0.04,0.05,0.10,0.15,0.20,0.25,0.30,0.35,0.40,0.45,0.50,0.55,0.6]
        rep_tape = [1, 2, 3, 4, 5, 6, 7, 8,10,12,15,20,25,30,50,60]
        target_errors = [0.03,0.05,0.07,0.1,0.15,0.2,0.25,0.3,0.35,0.4,0.45,0.5,0.55,0.6,0.65,0.7,0.8,0.9]
    elif args.dataset == 'cifar10_trial':
        # we use the datasets to get a general sense of attack success
        quantile_tape = [0.01,0.02,0.03,0.04,0.05,0.10,0.15,0.20,0.25,0.30,0.35,0.40,0.45,0.50,0.55,0.6]
        rep_tape = [1, 2, 3, 4, 5, 6, 7, 8,10,12,15,20,25,30,50,60]
        target_errors = [0.03,0.05,0.07,0.1,0.15,0.2,0.25,0.3,0.35,0.4,0.45,0.5,0.6,0.7]

    # Fail before loading data or training anything if the schedule is incomplete.
    if quantile_tape is None or rep_tape is None or target_errors is None:
        raise ValueError(
            "no complete quantile/repetition/target-error schedule is defined for dataset '{}'".format(args.dataset))

    # An ndarray makes the comparison and mask below independent of whether
    # the clean error is a Python float or a NumPy scalar.
    target_errors = np.asarray(target_errors, dtype=float)

    if args.dataset != 'cifar10_trial':
        key_epochs = [0]
    else:
        key_epochs = [-1] # [50,90,100,120]

    y_list = [-1,1]
    use_test = not args.use_train
    for key_epoch in key_epochs:
        X_train,Y_train,X_test,Y_test,_ = get_dataset(args,key_epoch)
        # train clean model and measure the error the attack has to exceed
        clean_model = train_model(X_train,Y_train,args)
        if use_test:
            clean_error = 1-clean_model.score(X_test,Y_test)
        else:
            clean_error = 1-clean_model.score(X_train,Y_train)

        # keep only targets above the clean error; filter from the full list
        # for every epoch instead of narrowing it cumulatively
        epoch_target_errors = target_errors[target_errors > clean_error]

        generate_target(quantile_tape,rep_tape,y_list,X_train,Y_train,X_test,Y_test,epoch_target_errors,args,use_test=use_test,epoch=key_epoch)

# Script entry point: this call is not guarded by `if __name__ == "__main__"`,
# so it runs whenever the module is executed or imported, using the
# module-level `args` parsed from the command line.
generate_target_model(args)
